{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "WKchP4VBBRgq"
   },
   "source": [
    "# Colocated Python\n",
    "\n",
    "NOTE: Colocated Python is currently an experimental API. Its functionality and\n",
    "interface are subject to change without following the standard JAX compatibility\n",
    "policy.\n",
    "\n",
    "Colocated Python provides a uniform way to run Python code on the hosts\n",
    "associated with a set of JAX devices. If the JAX devices represent local\n",
    "devices, the Python code will run on the local host. If the JAX devices\n",
    "represent remote devices, the Python code will be shipped to run on the host of\n",
    "these remote devices. This is useful when building a multi-host ML system on top\n",
    "of JAX that is portable across multi-controller JAX environments (running JAX\n",
    "code on each host with accelerators) as well as single-controller JAX\n",
    "environments (running JAX code on a single host orchestrating other hosts with\n",
    "accelerators)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "B38uuH1ZBZmd"
   },
   "source": [
    "## Colocated CPU devices\n",
    "\n",
    "To use colocated Python, the first step is to obtain CPU devices colocated with\n",
    "target accelerator devices.\n",
    "`jax.experimental.colocated_python.colocated_cpu_devices` provides a standard\n",
    "way to do so."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "d7FHtd4wCYEf"
   },
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.experimental.colocated_python as colocated_python\n",
    "\n",
    "devices = jax.devices()\n",
    "cpu_devices = colocated_python.colocated_cpu_devices(devices)\n",
    "print(cpu_devices)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Grfb7H4FCVsE"
   },
   "source": [
    "As usual, the CPU devices can be used with JAX APIs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "5RmWK-s4DQsl"
   },
   "outputs": [],
   "source": [
    "cpu_mesh = jax.sharding.Mesh(cpu_devices, [\"x\"])\n",
    "cpu_sharding = jax.sharding.NamedSharding(cpu_mesh, jax.P())\n",
    "x = jax.device_put(1, cpu_sharding)\n",
    "y = jax.jit(lambda x: x + 1)(x)\n",
    "print(y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7U1OScHaCjSC"
   },
   "source": [
    "## Colocated Python function\n",
    "\n",
    "CPU devices can also be used to run Python code with colocated Python."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "PJbdHF8mDZNT"
   },
   "outputs": [],
   "source": [
    "def f(x):\n",
    "  return x + 1\n",
    "\n",
    "\n",
    "f = colocated_python.colocated_python(f)\n",
    "y = f(x)\n",
    "assert y.sharding == x.sharding\n",
    "print(y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tpGdXqG9C5X3"
   },
   "source": [
    "Since colocated Python runs normal Python code, you can also perform I/O:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "MeWnKNlHDgs3"
   },
   "outputs": [],
   "source": [
    "def f(x):\n",
    "  with open('/tmp/foo', 'w') as f:\n",
    "    f.write(str(x))\n",
    "  return x\n",
    "\n",
    "\n",
    "f = colocated_python.colocated_python(f)\n",
    "jax.block_until_ready(f(x))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "HOGQQ5IUC7Pe"
   },
   "source": [
    "Note the use of `jax.block_until_ready` to ensure the Python code has\n",
    "completed. In principle, colocated Python calls may run asynchronously, similar\n",
    "to jitted function calls; the calls would return JAX arrays and do not block\n",
    "until their output is produced. Thus, you should block on an output from a\n",
    "colocated Python call if the completion of the execution is significant.\n",
    "\n",
    "There exist cases where a colocated Python call runs synchronously.\n",
    "\n",
    "* If the colocated Python function is called without \"specialization\" (see\n",
    "  below), the very first call will run synchronously. This is because the shape\n",
    "  and sharding of the output must be known for asynchronous execution, and\n",
    "  colocated Python has to run the Python code once to discover this information.\n",
    "\n",
    "* Some JAX backends do not yet fully support asynchronous execution, and will\n",
    "  fall back to synchronous execution.\n",
    "\n",
    "The wrapped Python code must use exactly the same set of devices in the input\n",
    "and the output. This is a requirement similar to jitted functions that represent\n",
    "an SPMD execution."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "uX8q-42tC8ia"
   },
   "source": [
    "## Specialization\n",
    "\n",
    "Specialization in colocated Python is a mechanism to supply extra information\n",
    "about the input, output, and execution of a colocated Python function, when the\n",
    "information cannot be inferred in advance, or you would like to ensure the\n",
    "colocated Python executions to happen precisely as specified.\n",
    "\n",
    "First, functions wrapped in colocated Python has a `specialize` method.\n",
    "This method is used to create another colocated Python wrapped function\n",
    "specialized with the supplied information.\n",
    "\n",
    "`out_specs_fn` is a function that takes a pytree of\n",
    "`jax.ShapeDtypeStruct` of the call inputs and returns a pytree of\n",
    "`jax.ShapeDtypeStruct` expected for the output. Calling this function is\n",
    "analogous to jitted function tracing, but this function is separate from the\n",
    "original Python code. This function runs on the caller side and not executed on\n",
    "the devices."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "SWEuz68nDtXE"
   },
   "outputs": [],
   "source": [
    "def f(x):\n",
    "  return x + 1\n",
    "\n",
    "\n",
    "f = colocated_python.colocated_python(f)\n",
    "f = f.specialize(out_specs_fn=lambda x: x)\n",
    "y = f(x)\n",
    "assert y.sharding == x.sharding"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "HkQZwqUBC-QV"
   },
   "source": [
    "`in_specs` takes a concrete pytree (the top level is tuple) of\n",
    "`jax.sharding.ShapeDtypeStruct` expected for the input to the colocated\n",
    "Python function call. This is used if a certain input spec must be used, or the\n",
    "output specs function can be computed only for a concrete input spec."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "id": "E0SQPPHID1WU"
   },
   "outputs": [],
   "source": [
    "import jax.numpy as jnp\n",
    "\n",
    "\n",
    "def f(x):\n",
    "  return x + 1\n",
    "\n",
    "\n",
    "f = colocated_python.colocated_python(f)\n",
    "f = f.specialize(\n",
    "    in_specs=(\n",
    "        # args\n",
    "        (\n",
    "            jax.ShapeDtypeStruct(\n",
    "                shape=(), dtype=jnp.int32, sharding=cpu_sharding\n",
    "            ),\n",
    "        ),\n",
    "        # kwargs\n",
    "        {},\n",
    "    ),\n",
    "    out_specs_fn=lambda x: jax.ShapeDtypeStruct(\n",
    "        shape=(), dtype=jnp.int32, sharding=cpu_sharding\n",
    "    ),\n",
    ")\n",
    "f(x)  # `x` must match the input spec."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2L7aUBvsC_4m"
   },
   "source": [
    "`devices` specifies a list of devices that the colocated Python function\n",
    "should run on. Having `devices` specialized lets a colocated Python function\n",
    "without input arguments run."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "id": "ZwWQRm_PDAll"
   },
   "outputs": [],
   "source": [
    "def f():\n",
    "  with open('/tmp/foo', 'w') as f:\n",
    "    f.write('foo')\n",
    "  return\n",
    "\n",
    "\n",
    "f = colocated_python.colocated_python(f)\n",
    "f = f.specialize(devices=cpu_devices)\n",
    "f()  # Would be an error if `f` is not specialized with ``devices``."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xIjM-au9DBQL"
   },
   "source": [
    "## Colocated Python class\n",
    "\n",
    "Colocated Python also supports wrapping Python classes. A real instance is\n",
    "created on the hosts associated with the devices, and the caller side will get a\n",
    "wrapper class that forwards all method calls to the real instance using\n",
    "colocated Python."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "id": "Ikb4Hh5iDB7Z"
   },
   "outputs": [],
   "source": [
    "class Adder:\n",
    "\n",
    "  def __init__(self, increment):\n",
    "    print('Adder created')\n",
    "    self.increment = increment\n",
    "\n",
    "  def __del__(self):\n",
    "    print('Adder destroyed')\n",
    "\n",
    "  def add(self, x):\n",
    "    return x + self.increment\n",
    "\n",
    "\n",
    "Adder = colocated_python.colocated_python_class(Adder)\n",
    "adder = Adder(1)\n",
    "x = jax.device_put(1, cpu_sharding)\n",
    "y = adder.add(x)\n",
    "print(y)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "t4i192BGDCw8"
   },
   "source": [
    "When the wrapper class instance is destroyed, the real instance is destroyed as\n",
    "well. Note that this destruction will be asynchronous."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "id": "j5g-NNYFDDln"
   },
   "outputs": [],
   "source": [
    "del adder"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "UfQTjAu9DEV-"
   },
   "source": [
    "There are a few important semantic differences between colocated Python and\n",
    "normal Python.\n",
    "\n",
    "* A colocated Python class instance is created only on the hosts associated with\n",
    "  the devices when any non-constructor method is called for the first time. In\n",
    "  the above example, `Adder(1)` captures the constructor arguments\n",
    "  `1`, but the actual constructor call `Adder(1)` on the hosts\n",
    "  happens only when the first `adder.add(x)` call is made. This is because\n",
    "  it is unknown what hosts the `Adder` instance should be created on until\n",
    "  there is a call to its method.\n",
    "\n",
    "* If the method(s) of the same wrapper class is called with inputs with\n",
    "  different devices, the real instance may be created at different times on\n",
    "  different hosts. If the first method call used CPU devices on host A, and the\n",
    "  second method call used CPU devices on host B, the real instance will be\n",
    "  created on host A during the first method call, and then on host B during the\n",
    "  second method call.\n",
    "\n",
    "* The methods of colocated Python classes are not yet specializable. The support\n",
    "  will be added in the future."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "YOsb92ChDFQd"
   },
   "source": [
    "## Execution order and concurrency\n",
    "\n",
    "Colocated Python provides \"program order\" execution. Even if colocated Python\n",
    "calls may be asynchronous (returning output JAX arrays without blocking), the\n",
    "calls will be executed in the same order as the order the calls are made in the\n",
    "user program. Thus, by default, colocated Python calls are sequentially\n",
    "executed.\n",
    "\n",
    "Several use cases of colocated Python will benefit from concurrent execution.\n",
    "For example, one colocated Python call may take long time to return because it\n",
    "may be doing expensive file reads, while another colocated Python call may need\n",
    "to do file writes that are independent from the first one. This situation could\n",
    "expect two calls to run concurrently without blocking each other.\n",
    "\n",
    "Colocated Python provides concurrent execution if colocated Python calls are\n",
    "made from different threads. For example, the below example would make two\n",
    "colocated Python calls to run concurrently."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "l0L1-HaGDGHo"
   },
   "outputs": [],
   "source": [
    "import concurrent.futures\n",
    "import time\n",
    "\n",
    "\n",
    "def f(x):\n",
    "  time.sleep(1)\n",
    "  return x + 1\n",
    "\n",
    "\n",
    "f = colocated_python.colocated_python(f)\n",
    "f = f.specialize(out_specs_fn=lambda x: x)  # Calls will be asynchronous.\n",
    "\n",
    "with concurrent.futures.ThreadPoolExecutor(2) as executor:\n",
    "  fut1 = executor.submit(f, x)\n",
    "  fut2 = executor.submit(f, x)\n",
    "  # Will finish in approximately 1 second instead of 2 seconds.\n",
    "  jax.block_until_ready([fut1.result(), fut2.result()])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "lRYja4_pDHFm"
   },
   "source": [
    "While calls from different threads run concurrently, on each thread, program\n",
    "ordering will continue to apply."
   ]
  }
 ],
 "metadata": {
  "colab": {
   "private_outputs": true
  },
  "jupytext": {
   "formats": "ipynb,md:myst",
   "main_language": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
